#!/bin/bash
# This script lets older versions of Triton generate PTX code up to version 7.8.
# See https://github.com/openai/triton/blob/8650b4d1cbc750d659156e2c17a058736614827b/lib/driver/llvm.cc#L149
# Abort on the first failing command, so a failed directory creation or shim
# build stops the script before anything is launched.
set -e

# Directory that holds the compiled shim library. The path is quoted so that a
# $HOME containing whitespace or glob characters is passed to mkdir as a
# single, literal argument.
mkdir -p -- "$HOME/.triton/"

# Build the shared library that gets LD_PRELOADed into the wrapped command.
# It defines triton::driver::vptx(int), which maps a CUDA version number to a
# PTX version (11040 -> 74, ..., 11080 and above -> 78) and throws for
# anything older than CUDA 11.4.
#
# The build is skipped when the cached library is newer than this script, so
# editing the script triggers a rebuild on the next run.
ptx_shim="$HOME/.triton/resolve_ptx_version.so"
if [[ ! "$ptx_shim" -nt "$0" ]]; then
  # Compile to a per-process temporary name in the same directory and rename
  # it into place afterwards: the rename is atomic, so a concurrent invocation
  # never preloads a half-written library.
  ptx_shim_tmp="$ptx_shim.tmp.$$"
  g++ -x c++ -fPIC -shared -o "$ptx_shim_tmp" - <<'CPP' || {
#include <stdexcept>
namespace triton {
namespace driver {

int vptx(int version) {
    // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#release-notes
    if (version >= 11080) return 78;
    if (version >= 11070) return 77;
    if (version >= 11060) return 76;
    if (version >= 11050) return 75;
    if (version >= 11040) return 74;
    throw std::runtime_error("Triton requires CUDA 11.4+");
}

}
}
CPP
    # Compilation failed: drop any partial output and propagate g++'s status
    # explicitly instead of relying on `set -e` alone.
    build_status=$?
    rm -f -- "$ptx_shim_tmp"
    exit "$build_status"
  }
  mv -f -- "$ptx_shim_tmp" "$ptx_shim"
fi

# Run the command given on the command line with the shim appended to
# LD_PRELOAD. When no command is given, the script only ensures the shim is
# built and exits successfully.
if (( $# > 0 )); then
  # - Test the argument count rather than the joined "$*" string, so an
  #   explicit empty-string argument is passed on instead of silently ignored.
  # - Keep any existing LD_PRELOAD entries first, but do not emit a stray
  #   leading ':' when LD_PRELOAD is unset or empty.
  # - Quote the whole assignment so paths with whitespace stay one word.
  # - exec replaces this shell, so signals and the exit status belong
  #   directly to the wrapped command.
  exec env \
    LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$HOME/.triton/resolve_ptx_version.so" \
    "$@"
fi